#----------------------------------------------------------------------
#  GFDM method test - 3d heat equation, Dirichlet BC, Implicit Euler
#  Author: Andrea Pavan
#  Date: 15/12/2022
#  License: GPLv3-or-later
#----------------------------------------------------------------------
using ElasticArrays;
using LinearAlgebra;
using SparseArrays;
using Printf;
using PyPlot;
include("utils.jl");


#problem definition
#geometry
l1 = 5.0        #domain x size
l2 = 2.5        #domain y size
l3 = 2.0        #domain z size

#boundary temperatures
uL = 400        #left border temperature
uR = 300        #right border temperature

#material properties
kcost = 5.0     #thermal conductivity
rho = 0.8       #density
ccost = 3.8     #specific heat capacity

#time discretization
t0 = 0.0        #starting time
tend = 20.0     #ending time
dt = 0.2        #time step

#initial conditions: uniform zero field
function u0(x,y,z)
    return 0.0
end

#Dirichlet boundary conditions: linear in x, equal to uL at x=0 and to uR at x=l1
function uD(x,y,z)
    return uL+x*(uR-uL)/l1
end

#pointcloud discretization
meshSize = 0.25                 #distance target between internal nodes
surfaceMeshSize = 0.25          #distance target between boundary nodes
minNeighbors = 15               #minimum number of neighbors allowed
minSearchRadius = meshSize/2    #starting search radius


#pointcloud generation
#Boundary nodes are appended first (lateral surface, then the two end caps at x=0 and x=l1),
#followed by the internal nodes, so boundary indices form the contiguous range 1:length(boundaryNodes).
time1 = time();
pointcloud = ElasticArray{Float64}(undef,3,0);      #3xN matrix containing the coordinates [X;Y;Z] of each node
boundaryNodes = Int[];        #indices of the boundary nodes (filled in once all boundary nodes exist)
normals = ElasticArray{Float64}(undef,3,0);     #3xN matrix containing the components [nx;ny;nz] of the normal of each boundary node

#membership test for the y-z cross-section used by every fill loop below:
#a point is accepted when abs(y)<(l2-l3)/2, or when it lies strictly inside a circle of
#radius l3/2 centered at (abs(y),z) = ((l2-l3)/2, -l3/2)
insideCrossSection(y,z) = abs(y)<(l2-l3)/2 || (abs(y)-(l2-l3)/2)^2+(z+l3/2)^2<(l3/2)^2;

#lateral surface: replicate the cross-section outline (and its normals, with nx=0) at every x station
(section,sectionnormals) = defaultCrossSection(l2, l3, surfaceMeshSize);
for x in 0:surfaceMeshSize:l1
    append!(pointcloud, vcat(fill(x,1,size(section,2)),section));
    append!(normals, vcat(zeros(Float64,1,size(sectionnormals,2)),sectionnormals));
end

#end caps: the same (y,z) grid is placed on both x=0 (normal -x) and x=l1 (normal +x)
for y in -l2/2+surfaceMeshSize:surfaceMeshSize:l2/2-surfaceMeshSize
    for z in -l3+surfaceMeshSize:surfaceMeshSize:0-surfaceMeshSize
        if insideCrossSection(y,z)
            append!(pointcloud, [0.0,y,z]);
            append!(normals, [-1.0,0.0,0.0]);
            append!(pointcloud, [l1,y,z]);
            append!(normals, [1.0,0.0,0.0]);
        end
    end
end
boundaryNodes = collect(1:size(pointcloud,2));

#internal nodes: cartesian grid with a random jitter of at most meshSize/10 per coordinate
#NOTE(review): unlike the end-cap loops, these ranges include z=-l3 and z=0, so for
#abs(y)<(l2-l3)/2 internal nodes are generated at the extreme z values (presumably on the
#boundary surface, possibly jittered outside of it) - confirm this is intended
for y in -l2/2:meshSize:l2/2
    for z in -l3:meshSize:0
        if insideCrossSection(y,z)
            for x in 0+meshSize:meshSize:l1-meshSize
                #append!(pointcloud, [x,y,z]);
                append!(pointcloud, [x,y,z]+(rand(Float64,3).-0.5).*meshSize/5);
            end
        end
    end
end
#=(octree,octreeSize,octreeCenter,octreePoints,octreeNpoints) = buildOctree(pointcloud);
octreeLeaves = findall(octreeNpoints.>=0);
for i in octreeLeaves
    if octreeNpoints[i] != 1
        if (abs(octreeCenter[2,i])<(l2-l3)/2 && octreeCenter[3,i]<=0) || (abs(octreeCenter[2,i])-(l2-l3)/2)^2+(octreeCenter[3,i]+l3/2)^2<(l3/2)^2
            append!(pointcloud, octreeCenter[:,i]);
        end
    end
end=#
#=NinternalPoints = 0;
while NinternalPoints<1500
    x = rand(0:1e-6:l1);
    y = rand(-l2/2:1e-6:l2/2);
    z = rand(-l3:1e-6:0);
    if abs(y)<(l2-l3)/2 || (abs(y)-(l2-l3)/2)^2+(z+l3/2)^2<(l3/2)^2
        append!(pointcloud, [x,y,z]);
        global NinternalPoints += 1;
    end
end=#

#every node appended after the boundary nodes is an internal node
internalNodes = collect(length(boundaryNodes)+1:size(pointcloud,2));
println("Generated pointcloud in $(round(time()-time1,digits=2)) s");
println("Pointcloud properties:");
println("  Boundary nodes: $(length(boundaryNodes))");
println("  Internal nodes: $(length(internalNodes))");
println("  Memory: $(memoryUsage(pointcloud,boundaryNodes))");


#neighbor search
#builds, for every node, the list of neighboring node indices and the neighbor count
time2 = time();
N = size(pointcloud,2);     #number of nodes
(neighbors,Nneighbors,cell) = cartesianNeighborSearch(pointcloud,meshSize,minNeighbors);
#(neighbors,Nneighbors) = quadrantNeighborSearch(pointcloud,meshSize);
println("Found neighbors in ", round(time()-time2,digits=2), " s");
println("Connectivity properties:");
let
    #findmax/findmin return the extreme value together with the index of its first occurrence
    (mostNeighbors,mostIdx) = findmax(Nneighbors);
    (fewestNeighbors,fewestIdx) = findmin(Nneighbors);
    avgNeighbors = round(sum(Nneighbors)/length(Nneighbors),digits=2);
    println("  Max neighbors: ",mostNeighbors," (at index ",mostIdx,")");
    println("  Avg neighbors: ",avgNeighbors);
    println("  Min neighbors: ",fewestNeighbors," (at index ",fewestIdx,")");
end

#connectivity plot
#draws the whole pointcloud in light gray and highlights the stencil of 5 randomly picked internal nodes
figure(1);
#plot1Idx = rand(1:N,5);
plot1Idx = rand(length(boundaryNodes)+1:N,5);
plot3D(pointcloud[1,:],pointcloud[2,:],pointcloud[3,:];marker=".",linestyle="None",color="lightgray");
for center in plot1Idx
    starColor = rand(3);        #one random RGB color per highlighted stencil
    star = neighbors[center];
    #neighbors of the selected node
    plot3D(pointcloud[1,star],pointcloud[2,star],pointcloud[3,star];marker=".",linestyle="None",color=starColor);
    #segments joining the selected node to each of its neighbors
    for nb in star
        ends = [center,nb];
        plot3D(pointcloud[1,ends],pointcloud[2,ends],pointcloud[3,ends],"-";color=starColor);
    end
end
#selected nodes in black on top of everything else
plot3D(pointcloud[1,plot1Idx],pointcloud[2,plot1Idx],pointcloud[3,plot1Idx],"k.");
title("Connectivity plot");
axis("equal");
display(gcf());


#neighbors distances and weights
#For every node i:
#  P[i]  - 3xNneighbors[i] matrix, column j is the offset from node i to its j-th neighbor
#  r2[i] - squared length of each offset
#  w2[i] - weight of each neighbor, exp(-r2/r2max)^2 with r2max the largest squared distance in the stencil
time3 = time();
P = Vector{Matrix{Float64}}(undef,N);       #relative positions of the neighbors
r2 = Vector{Vector{Float64}}(undef,N);      #relative distances of the neighbors
w2 = Vector{Vector{Float64}}(undef,N);      #neighbors weights
for i=1:N
    nn = Nneighbors[i];
    P[i] = Matrix{Float64}(undef,3,nn);
    r2[i] = Vector{Float64}(undef,nn);
    for j=1:nn
        #compute the offset once and reuse it for both the stored column and its squared norm
        d = pointcloud[:,neighbors[i][j]]-pointcloud[:,i];
        P[i][:,j] = d;
        r2[i][j] = dot(d,d);
    end
    r2max = maximum(r2[i]);
    w2[i] = exp.(-r2[i]./r2max).^2;
end


#least square matrix inversion
#For every internal node the weighted least-squares normal equations A*d = B*[u_i; u_neighbors]
#are built on the 9-term basis [x, y, z, x^2, y^2, z^2, x*y, x*z, y*z] of the neighbor offsets:
#  V    - Nneighbors x 9 matrix of the basis evaluated at each neighbor offset
#  A    = V'*W*V with W = Diagonal(w2[i])   (9x9 moment matrix)
#  B    = [-V'*w2[i]  V'*W]                 (9 x (1+Nneighbors); column 1 multiplies the central node value)
#  C    = A\B                               (derivatives coefficients)
A = Vector{Matrix{Float64}}(undef,N);       #least-squares matrices
condA = Vector{Float64}(undef,N);       #condition number
B = Vector{Matrix{Float64}}(undef,N);       #least-squares decomposition matrices
C = Vector{Matrix{Float64}}(undef,N);       #derivatives coefficients matrices
for i in internalNodes
    xj = P[i][1,:];
    yj = P[i][2,:];
    zj = P[i][3,:];
    V = hcat(xj, yj, zj, xj.^2, yj.^2, zj.^2, xj.*yj, xj.*zj, yj.*zj);
    WV = w2[i].*V;      #each row of V scaled by the weight of its neighbor
    A[i] = V'*WV;
    condA[i] = cond(A[i]);
    B[i] = hcat(-vec(sum(WV,dims=1)), permutedims(WV));
    #solve through the QR factorization instead of forming inv(R) explicitly
    C[i] = qr(A[i])\B[i];
end
println("Inverted least-squares matrices in ", round(time()-time3,digits=2), " s");
println("Matrices properties:");
println("  Max condition number: ",round(maximum(condA[internalNodes]),digits=2));
println("  Avg condition number: ",round(sum(condA[internalNodes])/length(internalNodes),digits=2));
println("  Min condition number: ",round(minimum(condA[internalNodes]),digits=2));


#star quality
#For every internal node: abs(1 - sum_j abs(a_j)/a_center), where a_center and a_j are the sums of
#rows 4,5,6 of C[i] for the central node (column 1) and for the j-th neighbor (column 1+j).
#Entries of boundary nodes are never computed and are left as NaN.
starquality = fill(NaN,N);      #error in satisfying the maximum principle
for i in internalNodes
    #Liszka definition
    acenter = C[i][4,1]+C[i][5,1]+C[i][6,1];
    starquality[i] = 1.0;
    for j=1:lastindex(neighbors[i])
        starquality[i] -= abs(C[i][4,1+j]+C[i][5,1+j]+C[i][6,1+j])/acenter;
    end
    starquality[i] = abs(starquality[i]);
end
println("Star properties:");
let
    internalQuality = starquality[internalNodes];
    (worstErr,worstPos) = findmax(internalQuality);
    (bestErr,bestPos) = findmin(internalQuality);
    #findmax/findmin return positions within internalNodes: map them back to node indices
    #so that the reported index can be used directly on pointcloud/neighbors
    println("  Max error: ",round(worstErr,digits=2)," (at index ",internalNodes[worstPos],")");
    println("  Avg error: ",round(sum(internalQuality)/length(internalNodes),digits=2));
    println("  Min error: ",round(bestErr,digits=2)," (at index ",internalNodes[bestPos],")");
end

#star quality plot
#internal nodes colored by the base-10 logarithm of their star quality error
figure(2);
let xyz = pointcloud[:,internalNodes], logError = log10.(starquality[internalNodes])
    scatter3D(xyz[1,:],xyz[2,:],xyz[3,:];c=logError,cmap="viridis");
end
title("Star quality");
axis("equal");
display(gcf());


#implicit Euler matrix assembly
#The matrix is collected in coordinate (row,col,value) form and converted to sparse at the end.
#  boundary rows: identity
#  internal rows: 1/dt on the diagonal minus the (rows 4+5+6 of C[i])*2*kcost/(rho*ccost) stencil
#NOTE(review): the factor 2 presumably turns the coefficients of the pure quadratic terms
#into second derivatives - confirm against the derivation
time4 = time();
rows = Int[];
cols = Int[];
vals = Float64[];
append!(rows, boundaryNodes);
append!(cols, boundaryNodes);
append!(vals, ones(length(boundaryNodes)));
for i in internalNodes
    Ci = C[i];
    #central node contribution
    push!(rows, i);
    push!(cols, i);
    push!(vals, 1/dt - (Ci[4,1]+Ci[5,1]+Ci[6,1])*2*kcost/(rho*ccost));
    #neighbors contribution: column 1+j of C[i] belongs to the j-th neighbor
    for (j,nb) in enumerate(neighbors[i])
        push!(rows, i);
        push!(cols, nb);
        push!(vals, -(Ci[4,1+j]+Ci[5,1+j]+Ci[6,1+j])*2*kcost/(rho*ccost));
    end
end
M = sparse(rows,cols,vals,N,N);
println("Completed matrix assembly in $(round(time()-time4,digits=2)) s");

#matrix plot: sparsity pattern of the assembled matrix
figure(3);
spy(M);
title("Implicit Euler matrix");
axis("equal");
display(gcf());


#time propagation
#Advances the solution from t0 to tend with the implicit Euler method: at every step the linear
#system M*u = b is solved, where b holds the Dirichlet value on boundary rows and uprev/dt on
#internal rows. The error is measured against ue, i.e. uD evaluated at every node.
time5 = time();
t = collect(t0:dt:tend);        #timesteps
u = u0.(pointcloud[1,:],pointcloud[2,:],pointcloud[3,:]);       #numerical solution
uprev = copy(u);        #numerical solution at the previous timestep
ue = uD.(pointcloud[1,:],pointcloud[2,:],pointcloud[3,:]);      #exact analytical solution
erru = ue-u;        #numerical solution error
rmseu = sqrt(erru'*erru)/sqrt(length(erru));        #root mean square of the numerical error
@printf("%6s | %6s | %12s | %12s\n","Step","Time","max(err(u))","RMSE(u)");
@printf("%6i | %6.2f | %12.4e | %12.4e\n",0,t0,maximum(abs.(erru)),rmseu);

#M does not change between steps: factorize it once and reuse the factorization
Mfact = lu(M);
#rhs vector: boundary entries do not depend on time, so they are set once
b = zeros(N);
for i in boundaryNodes
    b[i] = uD(pointcloud[1,i],pointcloud[2,i],pointcloud[3,i]);
end
for ti=2:lastindex(t)
    #implicit euler method
    global uprev = copy(u);
    for i in internalNodes
        b[i] = uprev[i]/dt;
    end
    global u = Mfact\b;

    #error calculation
    global erru = ue-u;        #numerical solution error
    global rmseu = sqrt(erru'*erru)/sqrt(length(erru));
    maxerru = maximum(abs.(erru));
    #t[1] is the initial state (printed above as step 0), so t[ti] is reached after ti-1 steps
    @printf("%6i | %6.2f | %12.4e | %12.4e\n",ti-1,t[ti],maxerru,rmseu);
    if maxerru>1e5
        println("ERROR: the solution diverged");
        break;
    end
end
println("Time integration completed in ", round(time()-time5,digits=2), " s");

#solution plot
#internal nodes colored by the numerical solution at the last computed timestep
figure(4);
let xyz = pointcloud[:,internalNodes], uInternal = u[internalNodes]
    scatter3D(xyz[1,:],xyz[2,:],xyz[3,:];c=uInternal,cmap="inferno");
end
title("Numerical solution");
axis("equal");
display(gcf());
